/**
 * 
 */
package cn.edu.bjtu.model.core;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import cn.edu.bjtu.api.ClassificationService;
import cn.edu.bjtu.api.NeuronNetwork;
import cn.edu.bjtu.core.ClassificationPair;
import cn.edu.bjtu.core.LocalFileModelReader;
import cn.edu.bjtu.core.LocalFileW2VModelReader;
import cn.edu.bjtu.core.LoggerSupport;
import cn.edu.bjtu.util.DatasetUtils;


/**
 * Text classification service: documents are turned into CNN input features
 * using the vectors of a {@link Word2VecModel}, scored by a {@link NeuronNetwork},
 * and each score row is mapped to {@link ClassificationPair}s through
 * {@link TextCategorizationManager}.
 *
 * @author alex
 */
public class TextClassificationModel extends LoggerSupport implements ClassificationService, Serializable {
	private static final long serialVersionUID = 1L;

	/** Number of documents per mini-batch handed to the data set iterator. */
	private static final int BATCH_SIZE = 32;

	/**
	 * Label attached to every input document. The sentence provider requires one label per
	 * sentence, but only the features of each batch are sent to the network.
	 */
	private static final String PLACEHOLDER_LABEL = "unknow";

	/** Maximum sentence length passed to {@link DatasetUtils#getCNNDataSet}. */
	int maxSentenceLength = 256;
	/** Network that scores the feature batches; loaded in the constructor. */
	NeuronNetwork net;
	/** Word2Vec model whose vectors are used to build the CNN features. */
	Word2VecModel w2v;

	/**
	 * Reads the Word2Vec model and the classification network from local files and loads the network.
	 * <p>
	 * If {@code org.apache.spark.SparkContext} can be loaded, a {@link DirectedNetworkWrapper} is tried
	 * first. If that probe or the wrapper's construction fails, a {@link Deep4jNetworkWrapper} is used
	 * instead. A {@link LinkageError} triggers the fallback as well as an {@link Exception}.
	 *
	 * @param cpuNum CPU count forwarded to {@link Deep4jNetworkWrapper} (only used by the fallback path)
	 * @throws Exception if a model cannot be read or the selected network fails to load
	 */
	public TextClassificationModel(int cpuNum) throws Exception {
		w2v = new Word2VecModel(new LocalFileW2VModelReader());
		try {
			// The DirectedNetworkWrapper is only attempted when SparkContext is on the classpath.
			Class.forName("org.apache.spark.SparkContext");
			net = new DirectedNetworkWrapper(new LocalFileModelReader());
		} catch (Exception | LinkageError ex) {
			// LinkageError (e.g. NoClassDefFoundError / ExceptionInInitializerError) is caught too, so a
			// partially available Spark classpath falls back instead of aborting construction.
			net = new Deep4jNetworkWrapper(new LocalFileModelReader(), cpuNum);
			logger.info("initial network function Deep4jNetworkWrapper is adopted,the function DirectedNetworkWrapper initila fail, the reason is {}", ex.getMessage());
		}
		net.load();
	}

	/**
	 * Creates a model with {@code cpuNum = -1}.
	 * NOTE(review): -1 presumably means "let the wrapper choose"; confirm in Deep4jNetworkWrapper.
	 *
	 * @throws Exception if a model cannot be read or the network fails to load
	 */
	public TextClassificationModel() throws Exception {
		this(-1);
	}

	/**
	 * Classifies a single document by delegating to the batch overload.
	 *
	 * @param doc the document text
	 * @return the classification pairs produced for {@code doc}
	 * @throws Exception if feature construction or network inference fails
	 */
	@Override
	public ClassificationPair[] classifyDocument(String doc) throws Exception {
		return classifyDocument(new String[]{doc})[0];
	}

	/**
	 * Classifies several documents.
	 *
	 * @param docs the document texts; must not be {@code null}
	 * @return one {@code ClassificationPair[]} per input document, in input order
	 * @throws Exception if feature construction or network inference fails
	 */
	@Override
	public ClassificationPair[][] classifyDocument(String[] docs) throws Exception {
		return classifyDocuments(docs);
	}

	/**
	 * Builds CNN features for {@code docs} in input order, runs them through the network batch by
	 * batch, and converts every output row into classification pairs.
	 */
	private ClassificationPair[][] classifyDocuments(String[] docs) throws Exception {
		if (docs.length == 0) {
			return new ClassificationPair[0][];
		}
		List<String> docList = new ArrayList<>(docs.length);
		List<String> labelList = new ArrayList<>(docs.length);
		for (String doc : docs) {
			docList.add(doc);
			labelList.add(PLACEHOLDER_LABEL);
		}
		// The third argument (random source) must be passed explicitly as null: this prevents
		// shuffling, so output rows stay aligned with the input order.
		LabeledSentenceProvider lsp = new CollectionLabeledSentenceProvider(docList, labelList, null);
		DataSetIterator dsi = DatasetUtils.getCNNDataSet(lsp, BATCH_SIZE, maxSentenceLength, w2v.getW2v());

		ClassificationPair[][] result = new ClassificationPair[docs.length][];
		// NOTE(review): if the iterator skips or pads documents, k will not end at docs.length
		// (trailing entries stay null, or an index overflow occurs); confirm against DatasetUtils.
		int k = 0;
		while (dsi.hasNext()) {
			DataSet ds = dsi.next();
			// assumes output()[0] has one row of scores per document in the batch — TODO confirm
			INDArray scores = net.output(ds.getFeatures())[0];
			for (int row = 0; row < scores.rows(); row++) {
				result[k++] = TextCategorizationManager.get().getDesc(scores.getRow(row));
			}
		}
		return result;
	}
}
